{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "假设我们有一支没有标明单位的温度计。我们先建立一个数据集，其中包含温度计的读数以及以我们熟悉的单位记录的对应温度值；然后选择一个模型，迭代地调整它的权重，直到误差的度量足够低；最终就能用我们理解的单位来解释新的读数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "t_c = [0.5, 14.0, 15.0, 28.0, 11.0, 8.0, 3.0, -4.0, 6.0, 13.0, 21.0]\n",
    "t_u = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]\n",
    "t_c = torch.tensor(t_c)\n",
    "t_u = torch.tensor(t_u)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 选择一个线性模型\n",
    "def model(t_u,w,b):\n",
    "    return w * t_u + b\n",
    "\n",
    "def loss_fn(tp,tc):\n",
    "    squared_diffs = (tp-tc)**2\n",
    "    return squared_diffs.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([35.7000, 55.9000, 58.2000, 81.9000, 56.3000, 48.9000, 33.9000, 21.8000,\n",
       "        48.4000, 60.4000, 68.4000])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Initialise the parameters: weight 1 and bias 0, both zero-dimensional tensors.\n",
    "w = torch.tensor(1.0)\n",
    "b = torch.tensor(0.0)\n",
    "\n",
    "# Forward pass with the initial parameters; the bare name displays the result.\n",
    "t_p = model(t_u, w, b)\n",
    "t_p"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(1763.8846)"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean squared error of the current predictions against the known temperatures.\n",
    "loss = loss_fn(tp=t_p, tc=t_c)\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate d(loss)/dw and d(loss)/db with central finite differences:\n",
    "# evaluate the loss at +/- delta around the current value of one parameter\n",
    "# (holding the other fixed) and divide the change by the interval width.\n",
    "delta = 0.1\n",
    "\n",
    "loss_rate_of_change_w = (\n",
    "    loss_fn(model(t_u, w + delta, b), t_c)\n",
    "    - loss_fn(model(t_u, w - delta, b), t_c)\n",
    ") / (2.0 * delta)\n",
    "\n",
    "loss_rate_of_change_b = (\n",
    "    loss_fn(model(t_u, w, b + delta), t_c)\n",
    "    - loss_fn(model(t_u, w, b - delta), t_c)\n",
    ") / (2.0 * delta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One gradient-descent step: move each parameter against its estimated slope.\n",
    "# Note: w and b are overwritten, so re-running this cell takes another step.\n",
    "lr = 1e-2\n",
    "w, b = (\n",
    "    w - lr * loss_rate_of_change_w,\n",
    "    b - lr * loss_rate_of_change_b,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 改用计算梯度\n",
    "def dloss_fn(t_p,t_c):\n",
    "    dsq_diffs = 2*(t_p - t_c) /t_p.size(0)\n",
    "    return dsq_diffs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dmodel_dw(t_u,w,b):\n",
    "    return t_u\n",
    "def dmodel_db(t_u,w,b):\n",
    "    return 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_fn(t_u, t_c, t_p, w, b):\n",
    "    \"\"\"Gradient of the loss with respect to ``w`` and ``b`` via the chain rule.\n",
    "\n",
    "    The per-element loss derivative from ``dloss_fn`` is multiplied by the\n",
    "    model derivatives from ``dmodel_dw`` / ``dmodel_db`` and summed over all\n",
    "    elements.\n",
    "\n",
    "    Returns:\n",
    "        A tensor stacking the two sums, ordered ``[dloss/dw, dloss/db]``.\n",
    "    \"\"\"\n",
    "    upstream = dloss_fn(t_p, t_c)\n",
    "    grad_w = (upstream * dmodel_dw(t_u, w, b)).sum()\n",
    "    grad_b = (upstream * dmodel_db(t_u, w, b)).sum()\n",
    "    return torch.stack([grad_w, grad_b])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_loop(n_epochs, learning_rate, params, t_u, t_c, print_every=1):\n",
    "    \"\"\"Fit ``params = [w, b]`` with plain gradient descent.\n",
    "\n",
    "    Args:\n",
    "        n_epochs: number of update steps to run.\n",
    "        learning_rate: step size multiplied into the gradient.\n",
    "        params: tensor holding the weight and the bias, in that order.\n",
    "        t_u: model inputs.\n",
    "        t_c: target values.\n",
    "        print_every: log the loss every this many epochs. The default of 1\n",
    "            logs every epoch; 0 or None disables logging.\n",
    "\n",
    "    Returns:\n",
    "        The parameters after the final update.\n",
    "    \"\"\"\n",
    "    for epoch in range(1, n_epochs + 1):\n",
    "        w, b = params\n",
    "        t_p = model(t_u, w, b)  # forward pass\n",
    "        loss = loss_fn(t_p, t_c)\n",
    "        grad = grad_fn(t_u, t_c, t_p, w, b)\n",
    "\n",
    "        params = params - learning_rate * grad\n",
    "\n",
    "        # The logged loss was computed before this epoch's update was applied.\n",
    "        if print_every and epoch % print_every == 0:\n",
    "            print('Epoch %d,Loss %f' % (epoch, float(loss)))\n",
    "    return params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1,Loss 1763.884644\n",
      "Epoch 2,Loss 323.090546\n",
      "Epoch 3,Loss 78.929634\n",
      "Epoch 4,Loss 37.552845\n",
      "Epoch 5,Loss 30.540285\n",
      "Epoch 6,Loss 29.351152\n",
      "Epoch 7,Loss 29.148882\n",
      "Epoch 8,Loss 29.113848\n",
      "Epoch 9,Loss 29.107145\n",
      "Epoch 10,Loss 29.105242\n",
      "Epoch 11,Loss 29.104168\n",
      "Epoch 12,Loss 29.103222\n",
      "Epoch 13,Loss 29.102297\n",
      "Epoch 14,Loss 29.101379\n",
      "Epoch 15,Loss 29.100470\n",
      "Epoch 16,Loss 29.099548\n",
      "Epoch 17,Loss 29.098631\n",
      "Epoch 18,Loss 29.097715\n",
      "Epoch 19,Loss 29.096796\n",
      "Epoch 20,Loss 29.095884\n",
      "Epoch 21,Loss 29.094959\n",
      "Epoch 22,Loss 29.094049\n",
      "Epoch 23,Loss 29.093134\n",
      "Epoch 24,Loss 29.092213\n",
      "Epoch 25,Loss 29.091297\n",
      "Epoch 26,Loss 29.090382\n",
      "Epoch 27,Loss 29.089460\n",
      "Epoch 28,Loss 29.088549\n",
      "Epoch 29,Loss 29.087635\n",
      "Epoch 30,Loss 29.086718\n",
      "Epoch 31,Loss 29.085808\n",
      "Epoch 32,Loss 29.084888\n",
      "Epoch 33,Loss 29.083965\n",
      "Epoch 34,Loss 29.083057\n",
      "Epoch 35,Loss 29.082142\n",
      "Epoch 36,Loss 29.081219\n",
      "Epoch 37,Loss 29.080309\n",
      "Epoch 38,Loss 29.079393\n",
      "Epoch 39,Loss 29.078474\n",
      "Epoch 40,Loss 29.077559\n",
      "Epoch 41,Loss 29.076653\n",
      "Epoch 42,Loss 29.075731\n",
      "Epoch 43,Loss 29.074812\n",
      "Epoch 44,Loss 29.073896\n",
      "Epoch 45,Loss 29.072985\n",
      "Epoch 46,Loss 29.072069\n",
      "Epoch 47,Loss 29.071148\n",
      "Epoch 48,Loss 29.070234\n",
      "Epoch 49,Loss 29.069323\n",
      "Epoch 50,Loss 29.068401\n",
      "Epoch 51,Loss 29.067486\n",
      "Epoch 52,Loss 29.066570\n",
      "Epoch 53,Loss 29.065655\n",
      "Epoch 54,Loss 29.064739\n",
      "Epoch 55,Loss 29.063829\n",
      "Epoch 56,Loss 29.062910\n",
      "Epoch 57,Loss 29.061989\n",
      "Epoch 58,Loss 29.061079\n",
      "Epoch 59,Loss 29.060169\n",
      "Epoch 60,Loss 29.059252\n",
      "Epoch 61,Loss 29.058332\n",
      "Epoch 62,Loss 29.057417\n",
      "Epoch 63,Loss 29.056507\n",
      "Epoch 64,Loss 29.055586\n",
      "Epoch 65,Loss 29.054670\n",
      "Epoch 66,Loss 29.053761\n",
      "Epoch 67,Loss 29.052843\n",
      "Epoch 68,Loss 29.051929\n",
      "Epoch 69,Loss 29.051014\n",
      "Epoch 70,Loss 29.050098\n",
      "Epoch 71,Loss 29.049183\n",
      "Epoch 72,Loss 29.048271\n",
      "Epoch 73,Loss 29.047346\n",
      "Epoch 74,Loss 29.046442\n",
      "Epoch 75,Loss 29.045530\n",
      "Epoch 76,Loss 29.044611\n",
      "Epoch 77,Loss 29.043699\n",
      "Epoch 78,Loss 29.042780\n",
      "Epoch 79,Loss 29.041870\n",
      "Epoch 80,Loss 29.040955\n",
      "Epoch 81,Loss 29.040039\n",
      "Epoch 82,Loss 29.039122\n",
      "Epoch 83,Loss 29.038214\n",
      "Epoch 84,Loss 29.037292\n",
      "Epoch 85,Loss 29.036379\n",
      "Epoch 86,Loss 29.035463\n",
      "Epoch 87,Loss 29.034557\n",
      "Epoch 88,Loss 29.033636\n",
      "Epoch 89,Loss 29.032721\n",
      "Epoch 90,Loss 29.031805\n",
      "Epoch 91,Loss 29.030895\n",
      "Epoch 92,Loss 29.029976\n",
      "Epoch 93,Loss 29.029066\n",
      "Epoch 94,Loss 29.028151\n",
      "Epoch 95,Loss 29.027235\n",
      "Epoch 96,Loss 29.026323\n",
      "Epoch 97,Loss 29.025410\n",
      "Epoch 98,Loss 29.024494\n",
      "Epoch 99,Loss 29.023582\n",
      "Epoch 100,Loss 29.022669\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([ 0.2327, -0.0438])"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train on the raw readings, starting from w=1, b=0.\n",
    "training_loop(\n",
    "    n_epochs=100,\n",
    "    learning_rate=1e-4,\n",
    "    params=torch.tensor([1.0, 0.0]),\n",
    "    t_u=t_u,\n",
    "    t_c=t_c,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1,Loss 80.364342\n",
      "Epoch 2,Loss 37.574917\n",
      "Epoch 3,Loss 30.871077\n",
      "Epoch 4,Loss 29.756193\n",
      "Epoch 5,Loss 29.507149\n",
      "Epoch 6,Loss 29.392458\n",
      "Epoch 7,Loss 29.298828\n",
      "Epoch 8,Loss 29.208717\n",
      "Epoch 9,Loss 29.119417\n",
      "Epoch 10,Loss 29.030487\n",
      "Epoch 11,Loss 28.941875\n",
      "Epoch 12,Loss 28.853565\n",
      "Epoch 13,Loss 28.765556\n",
      "Epoch 14,Loss 28.677851\n",
      "Epoch 15,Loss 28.590431\n",
      "Epoch 16,Loss 28.503321\n",
      "Epoch 17,Loss 28.416496\n",
      "Epoch 18,Loss 28.329975\n",
      "Epoch 19,Loss 28.243738\n",
      "Epoch 20,Loss 28.157801\n",
      "Epoch 21,Loss 28.072151\n",
      "Epoch 22,Loss 27.986799\n",
      "Epoch 23,Loss 27.901731\n",
      "Epoch 24,Loss 27.816954\n",
      "Epoch 25,Loss 27.732460\n",
      "Epoch 26,Loss 27.648256\n",
      "Epoch 27,Loss 27.564342\n",
      "Epoch 28,Loss 27.480711\n",
      "Epoch 29,Loss 27.397358\n",
      "Epoch 30,Loss 27.314295\n",
      "Epoch 31,Loss 27.231512\n",
      "Epoch 32,Loss 27.149006\n",
      "Epoch 33,Loss 27.066790\n",
      "Epoch 34,Loss 26.984844\n",
      "Epoch 35,Loss 26.903173\n",
      "Epoch 36,Loss 26.821791\n",
      "Epoch 37,Loss 26.740675\n",
      "Epoch 38,Loss 26.659838\n",
      "Epoch 39,Loss 26.579279\n",
      "Epoch 40,Loss 26.498987\n",
      "Epoch 41,Loss 26.418974\n",
      "Epoch 42,Loss 26.339228\n",
      "Epoch 43,Loss 26.259752\n",
      "Epoch 44,Loss 26.180548\n",
      "Epoch 45,Loss 26.101616\n",
      "Epoch 46,Loss 26.022949\n",
      "Epoch 47,Loss 25.944542\n",
      "Epoch 48,Loss 25.866417\n",
      "Epoch 49,Loss 25.788546\n",
      "Epoch 50,Loss 25.710936\n",
      "Epoch 51,Loss 25.633600\n",
      "Epoch 52,Loss 25.556524\n",
      "Epoch 53,Loss 25.479700\n",
      "Epoch 54,Loss 25.403149\n",
      "Epoch 55,Loss 25.326851\n",
      "Epoch 56,Loss 25.250811\n",
      "Epoch 57,Loss 25.175035\n",
      "Epoch 58,Loss 25.099510\n",
      "Epoch 59,Loss 25.024248\n",
      "Epoch 60,Loss 24.949238\n",
      "Epoch 61,Loss 24.874483\n",
      "Epoch 62,Loss 24.799980\n",
      "Epoch 63,Loss 24.725737\n",
      "Epoch 64,Loss 24.651735\n",
      "Epoch 65,Loss 24.577986\n",
      "Epoch 66,Loss 24.504494\n",
      "Epoch 67,Loss 24.431250\n",
      "Epoch 68,Loss 24.358257\n",
      "Epoch 69,Loss 24.285505\n",
      "Epoch 70,Loss 24.212996\n",
      "Epoch 71,Loss 24.140741\n",
      "Epoch 72,Loss 24.068733\n",
      "Epoch 73,Loss 23.996967\n",
      "Epoch 74,Loss 23.925446\n",
      "Epoch 75,Loss 23.854168\n",
      "Epoch 76,Loss 23.783125\n",
      "Epoch 77,Loss 23.712328\n",
      "Epoch 78,Loss 23.641771\n",
      "Epoch 79,Loss 23.571455\n",
      "Epoch 80,Loss 23.501379\n",
      "Epoch 81,Loss 23.431538\n",
      "Epoch 82,Loss 23.361938\n",
      "Epoch 83,Loss 23.292570\n",
      "Epoch 84,Loss 23.223436\n",
      "Epoch 85,Loss 23.154539\n",
      "Epoch 86,Loss 23.085882\n",
      "Epoch 87,Loss 23.017447\n",
      "Epoch 88,Loss 22.949251\n",
      "Epoch 89,Loss 22.881283\n",
      "Epoch 90,Loss 22.813547\n",
      "Epoch 91,Loss 22.746044\n",
      "Epoch 92,Loss 22.678770\n",
      "Epoch 93,Loss 22.611717\n",
      "Epoch 94,Loss 22.544899\n",
      "Epoch 95,Loss 22.478304\n",
      "Epoch 96,Loss 22.411938\n",
      "Epoch 97,Loss 22.345795\n",
      "Epoch 98,Loss 22.279875\n",
      "Epoch 99,Loss 22.214186\n",
      "Epoch 100,Loss 22.148710\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([ 2.7553, -2.5162])"
      ]
     },
     "execution_count": 89,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A single learning rate cannot suit both parameters at once (w and b differ\n",
    "# greatly), so change the input instead, i.e. normalise it.\n",
    "t_un = t_u * 0.1\n",
    "training_loop(\n",
    "    n_epochs=100,\n",
    "    learning_rate=1e-2,\n",
    "    params=torch.tensor([1.0, 0.0]),\n",
    "    t_u=t_un,\n",
    "    t_c=t_c,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
